{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "toc": true
   },
   "source": [
    "<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n",
    "<div class=\"toc\"><ul class=\"toc-item\"><li><span><a href=\"#Data\" data-toc-modified-id=\"Data-1\"><span class=\"toc-item-num\">1&nbsp;&nbsp;</span>Data</a></span></li><li><span><a href=\"#Model\" data-toc-modified-id=\"Model-2\"><span class=\"toc-item-num\">2&nbsp;&nbsp;</span>Model</a></span></li><li><span><a href=\"#Training\" data-toc-modified-id=\"Training-3\"><span class=\"toc-item-num\">3&nbsp;&nbsp;</span>Training</a></span></li><li><span><a href=\"#Explore-Latent-Space\" data-toc-modified-id=\"Explore-Latent-Space-4\"><span class=\"toc-item-num\">4&nbsp;&nbsp;</span>Explore Latent Space</a></span></li></ul></div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import yaml\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import functools\n",
    "from pathlib import Path\n",
    "from datetime import datetime\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "\n",
    "# Plotting\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import animation\n",
    "plt.rcParams['animation.ffmpeg_path'] = str(Path.home() / \"anaconda3/envs/image-processing/bin/ffmpeg\")\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import dcgan\n",
    "import gan_utils\n",
    "from load_data import preprocess_images\n",
    "from ds_utils.generative_utils import animate_latent_transition, gen_latent_linear, gen_latent_idx\n",
    "from ds_utils.plot_utils import plot_sample_imgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_folder = Path.home() / \"Documents/datasets\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model / training configuration.\n",
    "# safe_load only constructs plain Python objects (dicts, lists, scalars) and,\n",
    "# unlike a bare yaml.load(f), needs no explicit Loader argument\n",
    "# (PyYAML >= 6 rejects yaml.load without one).\n",
    "with open('configs/dcgan_celeba_config.yaml', 'r') as f:\n",
    "    config = yaml.safe_load(f)\n",
    "\n",
    "# Constants derived from the config and reused by the cells below.\n",
    "HIDDEN_DIM = config['data']['z_size']  # width of the random vectors fed to the generator\n",
    "IMG_SHAPE = config['data']['input_shape']\n",
    "BATCH_SIZE = config['training']['batch_size']\n",
    "# A third shape entry of 1 is treated as a black-and-white image.\n",
    "IMG_IS_BW = IMG_SHAPE[2] == 1\n",
    "# For black-and-white images keep only the first two entries so imshow gets a 2-D array.\n",
    "PLOT_IMG_SHAPE = IMG_SHAPE[:2] if IMG_IS_BW else IMG_SHAPE\n",
    "config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load Fashion MNIST dataset\n",
    "((X_train, y_train), (X_test, y_test)) = tf.keras.datasets.fashion_mnist.load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run both splits through the same preprocessing step.\n",
    "X_train = preprocess_images(X_train)\n",
    "X_test = preprocess_images(X_test)\n",
    "\n",
    "# Sanity output, one value per line: shape, max and min of the first\n",
    "# training sample, then the shape of the whole training array.\n",
    "print(X_train[0].shape, X_train[0].max(), X_train[0].min(), X_train.shape, sep='\\n')\n",
    "\n",
    "# The first preprocessed sample must have the input shape declared in the config.\n",
    "assert tuple(config['data']['input_shape']) == X_train[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds = tf.data.Dataset.from_tensor_slices(X_train).take(5000)\n",
    "test_ds = tf.data.Dataset.from_tensor_slices(X_test).take(256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append(\"../\")\n",
    "from tmp_load_data import load_imgs_tfdataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds = load_imgs_tfdataset(data_folder/'img_align_celeba', '*.jpg', config, 500, zipped=False)\n",
    "test_ds = load_imgs_tfdataset(data_folder/'img_align_celeba', '*.jpg', config, 100, zipped=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# instantiate GAN\n",
    "gan = dcgan.DCGan(IMG_SHAPE, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test generator\n",
    "generator_out = gan.generator.predict(np.random.randn(BATCH_SIZE, HIDDEN_DIM))\n",
    "generator_out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test discriminator\n",
    "discriminator_out = gan.discriminator.predict(generator_out)\n",
    "discriminator_out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test gan\n",
    "gan.gan.predict(np.random.randn(BATCH_SIZE, HIDDEN_DIM)).max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate one image from a single random latent vector and display it.\n",
    "sample_img = gan.generator.predict([np.random.randn(1, HIDDEN_DIM)])[0]\n",
    "plt.imshow(sample_img.reshape(PLOT_IMG_SHAPE),\n",
    "           cmap='gray' if IMG_IS_BW else 'jet')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gan.generator.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory layout for this model: <model_dir>/export and <model_dir>/logs/<timestamp>.\n",
    "# export_dir is later passed to training as checkpoint_dir, log_dir as log_dir.\n",
    "model_name = \"dcgan_celeba\"\n",
    "model_dir = Path.home() / \"Documents/models/tf_playground/gan\" / model_name\n",
    "export_dir = model_dir / 'export'\n",
    "# parents=True also creates model_dir itself when it does not exist yet.\n",
    "export_dir.mkdir(parents=True, exist_ok=True)\n",
    "# One log sub-directory per run, named after the current timestamp.\n",
    "run_stamp = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
    "log_dir = model_dir / \"logs\" / run_stamp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on the datasets prepared by gan.setup_dataset, handing log_dir and\n",
    "# export_dir to the trainer as its log and checkpoint directories.\n",
    "nb_epochs = 1000\n",
    "gan._train(\n",
    "    train_ds=gan.setup_dataset(train_ds),\n",
    "    validation_ds=gan.setup_dataset(test_ds),\n",
    "    nb_epochs=nb_epochs,\n",
    "    log_dir=log_dir,\n",
    "    checkpoint_dir=export_dir,\n",
    "    is_tfdataset=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export Keras model (.h5)\n",
    "gan.generator.save(str(export_dir / 'generator.h5'))\n",
    "gan.discriminator.save(str(export_dir / 'discriminator.h5'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show a plot_side x plot_side grid of freshly generated images.\n",
    "plot_side = 5\n",
    "nb_plot_imgs = plot_side * plot_side\n",
    "plot_sample_imgs(\n",
    "    # the argument supplied by plot_sample_imgs is ignored; new random latent vectors are drawn per call\n",
    "    lambda x: gan.generator.predict(np.random.randn(nb_plot_imgs, HIDDEN_DIM)),\n",
    "    img_shape=PLOT_IMG_SHAPE,\n",
    "    plot_side=plot_side,\n",
    "    cmap='gray' if IMG_IS_BW else 'jet',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Explore Latent Space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_image_fun(latent_vectors):\n",
    "    \"\"\"Run `latent_vectors` through the generator and return the first image.\n",
    "\n",
    "    Only element 0 of the generator output is kept (reshaped to\n",
    "    PLOT_IMG_SHAPE); any further entries in the batch are discarded.\n",
    "    \"\"\"\n",
    "    generated_batch = gan.generator.predict(latent_vectors)\n",
    "    return generated_batch[0].reshape(PLOT_IMG_SHAPE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke-test gen_image_fun on one random latent vector.\n",
    "# z_s is first assigned in a later cell, so referencing it here raised a\n",
    "# NameError when the notebook was run top to bottom on a fresh kernel.\n",
    "img = gen_image_fun(np.random.randn(1, HIDDEN_DIM))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Animate a walk through latent space between randomly drawn z vectors.\n",
    "nb_samples = 10\n",
    "nb_transition_frames = 10\n",
    "# (nb_samples - 1) gaps between consecutive samples, capped at 2000 frames\n",
    "nb_frames = min(2000, nb_transition_frames * (nb_samples - 1))\n",
    "\n",
    "# random list of z vectors\n",
    "z_s = np.random.randn(nb_samples, HIDDEN_DIM)\n",
    "\n",
    "render_dir = Path.home() / 'Documents/videos/gan' / \"gan_celeba\"\n",
    "\n",
    "animate_latent_transition(\n",
    "    latent_vectors=z_s,\n",
    "    gen_image_fun=gen_image_fun,\n",
    "    # gen_latent_linear presumably interpolates between neighbouring vectors -- confirm in ds_utils\n",
    "    gen_latent_fun=lambda z_s, i: gen_latent_linear(z_s, i, nb_transition_frames),\n",
    "    img_size=PLOT_IMG_SHAPE,\n",
    "    nb_frames=nb_frames,\n",
    "    render_dir=render_dir,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the directory name says fmnist although the config loaded\n",
    "# at the top of the notebook is the celeba one -- confirm the intended target.\n",
    "render_dir = Path.home() / 'Documents/videos/gan' / \"gan_fmnist_test\"\n",
    "\n",
    "nb_transition_frames = 10\n",
    "\n",
    "# single random starting point in latent space\n",
    "z_start = np.random.randn(1, HIDDEN_DIM)\n",
    "# values handed to gen_latent_idx, one per frame\n",
    "# (presumably written into latent dimension z_idx -- confirm in ds_utils)\n",
    "vals = np.linspace(-1., 1., nb_transition_frames)\n",
    "\n",
    "for z_idx in range(20):\n",
    "    animate_latent_transition(\n",
    "        latent_vectors=z_start,\n",
    "        gen_image_fun=gen_image_fun,\n",
    "        # Bind z_idx as a default argument: a plain closure looks z_idx up only when\n",
    "        # the lambda is called, so any deferred rendering would see a later loop value.\n",
    "        gen_latent_fun=lambda z_s, i, z_idx=z_idx: gen_latent_idx(z_s, i, z_idx, vals),\n",
    "        img_size=PLOT_IMG_SHAPE,\n",
    "        nb_frames=nb_transition_frames,\n",
    "        render_dir=render_dir,\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Data Science",
   "language": "python",
   "name": "data-science"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  },
  "notify_time": "30",
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {
    "height": "592px",
    "left": "0px",
    "right": "1056.75px",
    "top": "87px",
    "width": "143.25px"
   },
   "toc_section_display": "block",
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
